import os
import pprint
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tianshou.utils import BasicLogger

from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch


def get_args():
    parser = get_parser()
    parser.add_argument('--self_play_round', type=int, default=20)
    args = parser.parse_known_args()[0]
    return args


def gomoku(args=get_args()):
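    """Self-play training: the learner repeatedly plays against a growing
    pool of frozen copies of its past selves, and harder opponents are
    sampled with higher probability.
    """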
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
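    # the policy manager holds one sub-policy per player; agent ids are
    # 1-based, so the learner is at index agent_id - 1 and its opponent
    # at index 2 - agent_id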
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    writer = SummaryWriter(log_path)
    args.logger = BasicLogger(writer)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)
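    # these vectorized envs are only used to evaluate the learner against
    # the pool; TicTacToeEnv generalizes to Gomoku via board_size / win_size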
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        rews = []
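        # evaluate greedily: no exploration noise during evaluation games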
        agent_learn.set_eps(0.0)
        # evaluate the current learner against every opponent in the pool;
        # collect() returns per-agent rewards with shape (n_episode, 2),
        # so take the learner's column for its mean episode reward
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rews'][:, args.agent_id - 1].mean())
        rews = np.array(rews)
        # weight opponents by difficulty: exp(-10 * rew) gives opponents the
        # learner scores poorly against a larger sampling probability
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
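        # train one epoch at a time so a fresh opponent can be resampled
        # from the pool before every epoch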
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent from the pool for this epoch
            opp_id = np.random.choice(len(opponent_pool), p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id]
            # freeze the sampled opponent: wrap its forward pass in a
            # RandomPolicy shell so it can act but is never updated
            agent = RandomPolicy()
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth')
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn,
                agent_opponent=agent, optim=optim)
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
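        # freeze a greedy (eps=0) copy of the learner and add it to the
        # opponent pool for the next round of self-play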
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
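        # restore the original epoch count for the next round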
        args.epoch = total_epoch
    if __name__ == '__main__':
        # Let's watch its performance! opponent_pool[-1] is a frozen copy of
        # the final learner, so play against the previous generation instead.
        opponent = opponent_pool[-2]
        watch(args, agent_learn, opponent)


if __name__ == '__main__':
    gomoku(get_args())
